{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A Spark Implementation of the DBSCAN Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### I. Implementation Approach"
   ]
  },
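  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A distributed implementation of DBSCAN has to solve the following main problems.\n",
    "\n",
    "\n",
    "**1. How do we compute the pairwise distances between sample points?**\n",
    "\n",
    "On a single machine, computing pairwise distances is straightforward: it is a doubly nested loop over the points.\n",
    "To reduce the amount of computation, a spatial index such as an R-tree can be used.\n",
    "\n",
    "In a distributed environment the sample points live in different partitions, so a doubly nested loop across partitions cannot be run directly.\n",
    "My solution is to pull the sample points to the driver in batches, one partition at a time,\n",
    "broadcast each batch to the executors so that every executor computes distances against it, and union the results, which realizes the double loop indirectly.\n",
    "\n",
    "To reduce the amount of computation, an R-tree spatial index is built on the driver over each batch before it is broadcast.\n",
    "\n",
    "\n",
    "\n",
    "**2. How do we build the temporary clusters?**\n",
    "\n",
    "This problem is not hard, and the single-machine and distributed implementations are almost the same.\n",
    "\n",
    "Both use a group-by to count, for every sample point, the sample points that lie within the neighborhood radius R\n",
    "\n",
    "and to record their ids. If that count is at least minpoints, a temporary cluster is built and the point is added to the list of core points.\n",
    "\n",
    "\n",
    "**3. How do we merge connected temporary clusters into final clusters?**\n",
    "\n",
    "This is the most critical step of the distributed implementation.\n",
    "\n",
    "On a single machine, the standard approach is to take each temporary cluster in turn\n",
    "\n",
    "and check whether each of its sample points is in the core point list. If it is, the temporary cluster built around that sample point is merged into the current temporary cluster, and the point is removed from the core point list.\n",
    "\n",
    "This is repeated until none of the points in the current temporary cluster is in the core point list.\n",
    "\n",
    "In a distributed environment the temporary clusters are spread over different partitions, so the global core point list cannot be scanned directly to merge them.\n",
    "\n",
    "My solution is to first merge the temporary clusters inside each partition, then repartition into fewer partitions and merge again inside each partition.\n",
    "\n",
    "This process is repeated until all temporary clusters end up in a single partition, where the merge of all temporary clusters is completed.\n",
    "\n",
    "To relieve the storage pressure on that last partition, I use a merge algorithm that differs from the standard one.\n",
    "\n",
    "For each temporary cluster only the ids of its core points are kept, and the ids of non-core points are ignored. Temporary clusters that share a core point id are merged.\n",
    "\n",
    "To speed up the merging, the partitioning is not random. The smallest core point id in each temporary cluster, min_core_id, is used as the hash key for partitioning,\n",
    "\n",
    "so temporary clusters that share core point ids are more likely to land in the same partition, which speeds up the merge.\n"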
   ]
  },
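  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The merge strategy in point 3 can be tried out without a cluster. The following is a minimal plain-Scala sketch, not the Spark code of this notebook: `mergeLocal` and `mergeRound` are hypothetical helper names, and bucketing by `min % partitionCnt` only stands in for hash partitioning on min_core_id. The input is the same seven temporary clusters used in the core code section.\n",
    "\n",
    "```scala\n",
    "// Merge all sets that share at least one element.\n",
    "def mergeLocal(sets: List[Set[Long]]): List[Set[Long]] =\n",
    "  sets.foldLeft(List.empty[Set[Long]]) { (merged, s) =>\n",
    "    // Split the clusters built so far into those that touch s and the rest.\n",
    "    val (hit, rest) = merged.partition(m => (m & s).nonEmpty)\n",
    "    hit.foldLeft(s)(_ | _) :: rest\n",
    "  }\n",
    "\n",
    "// One round: bucket the clusters by min_core_id modulo the partition count, then merge per bucket.\n",
    "def mergeRound(sets: List[Set[Long]], partitionCnt: Int): List[Set[Long]] =\n",
    "  sets.groupBy(s => (s.min % partitionCnt).toInt).values.flatMap(b => mergeLocal(b)).toList\n",
    "\n",
    "val tempClusters = List(\n",
    "  Set(1L, 2L), Set(2L, 3L, 4L), Set(6L, 8L, 9L), Set(4L, 5L),\n",
    "  Set(9L, 10L, 11L), Set(15L, 17L), Set(10L, 11L, 18L))\n",
    "\n",
    "// Shrink the partition count round by round; the last round uses a single bucket.\n",
    "val clusters = List(8, 4, 1).foldLeft(tempClusters)((acc, n) => mergeRound(acc, n))\n",
    "\n",
    "clusters.map(c => (c.min, c)).sortBy(_._1).foreach(println)\n",
    "```\n",
    "\n",
    "Because the last round puts everything into a single bucket, the final clusters do not depend on how the earlier rounds split the data. The earlier rounds only reduce how much work is left for the last one."
   ]
  },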
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./data/DBSCAN算法步骤.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### II. Core Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "spark = org.apache.spark.sql.SparkSession@51a18d93\n",
       "sc = org.apache.spark.SparkContext@35a42a03\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "org.apache.spark.SparkContext@35a42a03"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import org.apache.spark.sql.SparkSession\n",
    "\n",
    "val spark = SparkSession\n",
    ".builder()\n",
    ".appName(\"dbscan\")\n",
    ".getOrCreate()\n",
    "\n",
    "val sc = spark.sparkContext\n",
    "import spark.implicits._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**1. Find the core points and form the temporary clusters.**"
   ]
  },
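  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This step normally uses a spatial index plus broadcasting. It is omitted here, and we assume the temporary clusters have already been obtained."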
   ]
  },
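  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about what the omitted step produces, here is a minimal brute-force sketch in plain Scala. It is an illustration under assumptions, not the spatial index plus broadcast method: the toy `points`, `eps` and `minPoints` values are hypothetical, and all points are held in memory on one machine.\n",
    "\n",
    "```scala\n",
    "// Hypothetical toy data: (id, x, y). eps and minPoints are illustrative values.\n",
    "val points = List((0L, 0.0, 0.0), (1L, 0.1, 0.0), (2L, 0.2, 0.0), (3L, 0.3, 0.0), (4L, 5.0, 5.0))\n",
    "val eps = 0.15\n",
    "val minPoints = 3\n",
    "\n",
    "// Neighborhood of every point (the point itself included), found by brute force.\n",
    "val neighbours: Map[Long, Set[Long]] = points.map { case (id, x, y) =>\n",
    "  id -> points.collect { case (j, u, v) if math.hypot(x - u, y - v) < eps => j }.toSet\n",
    "}.toMap\n",
    "\n",
    "// Core points have at least minPoints points in their neighborhood.\n",
    "val coreIds: Set[Long] = neighbours.collect { case (id, ns) if ns.size >= minPoints => id }.toSet\n",
    "\n",
    "// One temporary cluster per core point, keeping only the core ids of its neighborhood.\n",
    "val tempClusters: List[(Long, Set[Long])] = coreIds.toList.sorted.map { id =>\n",
    "  val coreIdSet = neighbours(id) & coreIds\n",
    "  (coreIdSet.min, coreIdSet)\n",
    "}\n",
    "\n",
    "tempClusters.foreach(println)\n",
    "```\n",
    "\n",
    "Each resulting tuple has the same (min_core_id, core_id_set) shape as a row of rdd_core in the next cell."
   ]
  },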
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1,Set(1, 2))\n",
      "(2,Set(2, 3, 4))\n",
      "(6,Set(6, 8, 9))\n",
      "(4,Set(4, 5))\n",
      "(9,Set(9, 10, 11))\n",
      "(15,Set(15, 17))\n",
      "(10,Set(10, 11, 18))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "rdd_core = ParallelCollectionRDD[1] at parallelize at <console>:34\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "ParallelCollectionRDD[1] at parallelize at <console>:34"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Each row of rdd_core is one temporary cluster: (min_core_id, core_id_set)\n",
    "// core_id_set holds the ids of all core points in the temporary cluster; min_core_id is the smallest of those ids\n",
    "var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),\n",
    "                                       (6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),\n",
    "                                       (9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),\n",
    "                                       (10L,Set(10L,11L,18L))))\n",
    "rdd_core.collect.foreach(println)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](data/dbscan核心算法的输入.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**2. Merge the temporary clusters into final clusters.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "mergeSets = > scala.collection.mutable.ListBuffer[Set[Long]] = <function1>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "mergeRDD: (rdd: org.apache.spark.rdd.RDD[(Long, Set[Long])], partition_cnt: Int)org.apache.spark.rdd.RDD[(Long, Set[Long])]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "> scala.collection.mutable.ListBuffer[Set[Long]] = <function1>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import scala.collection.mutable.ListBuffer\n",
    "import org.apache.spark.HashPartitioner\n",
    "\n",
    "// Merge function: merges temporary clusters that share a core point\n",
    "val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{\n",
    "  var result = ListBuffer[Set[Long]]()\n",
    "  while (set_list.size>0){\n",
    "    var cur_set = set_list.remove(0)\n",
    "    var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n",
    "    while(intersect_idxs.size>0){\n",
    "      for(idx<-intersect_idxs){\n",
    "        cur_set = cur_set|set_list(idx)\n",
    "      }\n",
    "      for(idx<-intersect_idxs){\n",
    "        set_list.remove(idx)\n",
    "      }\n",
    "      intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n",
    "    }\n",
    "    result = result:+cur_set\n",
    "  }\n",
    "  result\n",
    "}\n",
    "\n",
    "// Partition rdd_core and merge inside each partition; keep reducing the partition count until everything is merged in a single partition\n",
    "// If the data is too large to be merged in one partition, the final merge can stop at several partitions, which gives an approximate result.\n",
    "// rdd: (min_core_id, core_id_set)\n",
    "\n",
    "def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):\n",
    "org.apache.spark.rdd.RDD[(Long,Set[Long])] = {\n",
    "  val rdd_merged =  rdd.partitionBy(new HashPartitioner(partition_cnt))\n",
    "    .mapPartitions(iter => {\n",
    "      val buffer = ListBuffer[Set[Long]]()\n",
    "      for(t<-iter){\n",
    "        val core_id_set:Set[Long] = t._2\n",
    "        buffer.append(core_id_set)\n",
    "      }\n",
    "      val merged_buffer = mergeSets(buffer)\n",
    "      var result = List[(Long,Set[Long])]()\n",
    "      for(core_id_set<-merged_buffer){\n",
    "        val min_core_id = core_id_set.min\n",
    "        result = result:+(min_core_id,core_id_set)\n",
    "      }\n",
    "      result.iterator\n",
    "    })\n",
    "  rdd_merged\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1,Set(5, 1, 2, 3, 4))\n",
      "(6,Set(10, 6, 9, 18, 11, 8))\n",
      "(15,Set(15, 17))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "rdd_core = MapPartitionsRDD[7] at mapPartitions at <console>:63\n",
       "rdd_core = MapPartitionsRDD[7] at mapPartitions at <console>:63\n",
       "rdd_core = MapPartitionsRDD[7] at mapPartitions at <console>:63\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "MapPartitionsRDD[7] at mapPartitions at <console>:63"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Iterate over shrinking partition counts; adjust the number of rounds and the partition counts as needed\n",
    "rdd_core = mergeRDD(rdd_core,8)\n",
    "rdd_core = mergeRDD(rdd_core,4)\n",
    "rdd_core = mergeRDD(rdd_core,1)\n",
    "rdd_core.collect.foreach(println)"
   ]
  },
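  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where each temporary cluster lands in the first round, the partition index can be reproduced in plain Scala. This sketch assumes that HashPartitioner assigns a key to `key.hashCode` modulo the number of partitions, made non-negative; `partitionOf` is a hypothetical helper, not a Spark API.\n",
    "\n",
    "```scala\n",
    "// Partition index for a Long key, under the assumption stated above.\n",
    "def partitionOf(key: Long, numPartitions: Int): Int = {\n",
    "  val raw = key.hashCode % numPartitions\n",
    "  if (raw < 0) raw + numPartitions else raw\n",
    "}\n",
    "\n",
    "// min_core_id values of the seven temporary clusters in rdd_core.\n",
    "val minCoreIds = List(1L, 2L, 6L, 4L, 9L, 15L, 10L)\n",
    "\n",
    "// Layout of the first round with 8 partitions: partition index -> keys.\n",
    "val layout = minCoreIds.groupBy(k => partitionOf(k, 8))\n",
    "layout.toList.sortBy(_._1).foreach(println)\n",
    "```\n",
    "\n",
    "The keys of later rounds are the min_core_id values of the merged clusters, so the layout changes from round to round."
   ]
  },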
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](data/dbscan核心算法的输出.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### III. Complete Example"
   ]
  },
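  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The complete example below uses exactly the same data source and parameters as the article 《20分钟学会DBSCAN聚类算法》 (\"Learn the DBSCAN Clustering Algorithm in 20 Minutes\") and produces exactly the same result.\n",
    "\n",
    "The difference is that the code below is a distributed implementation on Spark, so it scales well to large data sets.\n"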
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "spark = org.apache.spark.sql.SparkSession@51a18d93\n",
       "sc = org.apache.spark.SparkContext@35a42a03\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "printlog: (info: String)Unit\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "org.apache.spark.SparkContext@35a42a03"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import org.apache.spark.sql.SparkSession\n",
    "import org.apache.spark.storage.StorageLevel\n",
    "import org.apache.spark.sql.{DataFrame, Row, Column}\n",
    "import org.apache.spark.sql.functions._\n",
    "import org.locationtech.jts.geom.{Geometry,GeometryFactory,Coordinate,Point}\n",
    "import org.locationtech.jts.index.strtree.STRtree\n",
    "import org.apache.spark.sql.jts.registerTypes\n",
    "import scala.collection.mutable.WrappedArray\n",
    "import scala.collection.JavaConversions._\n",
    "\n",
    "def printlog(info:String): Unit ={\n",
    "    val dt = new java.text.SimpleDateFormat(\"yyyy-MM-dd HH:mm:ss\").format(new java.util.Date)\n",
    "    println(\"==========\"*8+dt)\n",
    "    println(info+\"\\n\")\n",
    "}\n",
    "\n",
    "val spark = SparkSession\n",
    "    .builder()\n",
    "    .appName(\"dbscan\")\n",
    "    .enableHiveSupport()\n",
    "    .getOrCreate()\n",
    "\n",
    "val sc = spark.sparkContext\n",
    "import spark.implicits._\n",
    "registerTypes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:06:12\n",
      "step1: get input data -> dfinput ...\n",
      "\n",
      "+---+--------------------+\n",
      "| id|            geometry|\n",
      "+---+--------------------+\n",
      "|  0|POINT (0.31655567...|\n",
      "|  1|POINT (0.74088269...|\n",
      "|  2|POINT (0.87172637...|\n",
      "|  3|POINT (0.55552787...|\n",
      "|  4|POINT (2.03872887...|\n",
      "|  5|POINT (1.99136342...|\n",
      "|  6|POINT (0.22384428...|\n",
      "|  7|POINT (0.97295674...|\n",
      "|  8|POINT (-0.9213036...|\n",
      "|  9|POINT (0.46670632...|\n",
      "| 10|POINT (0.49217803...|\n",
      "| 11|POINT (-0.4223529...|\n",
      "| 12|POINT (0.31358610...|\n",
      "| 13|POINT (0.64848081...|\n",
      "| 14|POINT (0.31549460...|\n",
      "| 15|POINT (-0.9118786...|\n",
      "| 16|POINT (1.70164131...|\n",
      "| 17|POINT (0.10851453...|\n",
      "| 18|POINT (-0.3098724...|\n",
      "| 19|POINT (-0.2040816...|\n",
      "+---+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dfdata = [feature1: double, feature2: double]\n",
       "dfinput = [id: bigint, geometry: point]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[id: bigint, geometry: point]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  1. Read the input data: dfinput\n",
    "/*================================================================================*/\n",
    "printlog(\"step1: get input data -> dfinput ...\")\n",
    "\n",
    "val dfdata = spark.read.option(\"header\",\"true\") \n",
    "  .option(\"inferSchema\",\"true\") \n",
    "  .option(\"delimiter\", \"\\t\") \n",
    "  .csv(\"data/moon_points.csv\")\n",
    "\n",
    "spark.udf.register(\"makePoint\", (x:Double,y:Double) =>{\n",
    "    val gf = new GeometryFactory\n",
    "    val pt = gf.createPoint(new Coordinate(x,y))\n",
    "    pt\n",
    "})\n",
    "\n",
    "val dfinput = dfdata.selectExpr(\"makePoint(feature1,feature2) as point\")\n",
    "            .rdd.map(row=>row.getAs[Point](\"point\"))\n",
    "            .zipWithIndex().toDF(\"geometry\",\"id\").selectExpr(\"id\",\"geometry\")\n",
    "            .persist(StorageLevel.MEMORY_AND_DISK)\n",
    "\n",
    "dfinput.show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:06:26\n",
      "step2: looking for neighbours by broadcasting Rtree -> dfnear ...\n",
      "\n",
      "+-----+-----+--------------------+\n",
      "|s_fid|m_fid|              s_geom|\n",
      "+-----+-----+--------------------+\n",
      "|   19|  271|POINT (-0.2040816...|\n",
      "|   19|  489|POINT (-0.2040816...|\n",
      "|   19|  488|POINT (-0.2040816...|\n",
      "+-----+-----+--------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "partition_cnt = 10\n",
       "rdd_input = MapPartitionsRDD[40] at repartition at <console>:64\n",
       "dfbuffer = [id: bigint, envelop: geometry]\n",
       "dfnear = [s_fid: bigint, m_fid: bigint ... 1 more field]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[s_fid: bigint, m_fid: bigint ... 1 more field]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  2. Broadcast the R-tree in batches to get the neighbor candidates: dfnear\n",
    "/*================================================================================*/\n",
    "printlog(\"step2: looking for neighbours by broadcasting Rtree -> dfnear ...\")\n",
    "\n",
    "spark.udf.register(\"getBufferBox\", (p: Point) => p.getEnvelope.buffer(0.2).getEnvelope)\n",
    "\n",
    "// Broadcast batch by batch\n",
    "val partition_cnt = 10\n",
    "val rdd_input = dfinput.rdd.repartition(20).persist(StorageLevel.MEMORY_AND_DISK)\n",
    "val dfbuffer = dfinput.selectExpr(\"id\",\"getBufferBox(geometry) as envelop\").repartition(partition_cnt)\n",
    "var dfnear = List[(Long,Long,Point)]().toDF(\"s_fid\",\"m_fid\",\"s_geom\")\n",
    "\n",
    "\n",
    "for(partition_id <- 0 until partition_cnt){\n",
    "  val bufferi = dfbuffer.rdd.mapPartitionsWithIndex(\n",
    "    (idx, iter) => if (idx == partition_id ) iter else Iterator())\n",
    "  val Rtree = new STRtree()\n",
    "  bufferi.collect.foreach(x => Rtree.insert(x.getAs[Geometry](\"envelop\").getEnvelopeInternal, x))\n",
    "  val tree_broads = sc.broadcast(Rtree)\n",
    "\n",
    "  val dfneari = rdd_input.mapPartitions(iter => {\n",
    "    var res_list = List[(Long,Long,Point)]()//s_fid,m_fid,s_geom\n",
    "    val tree = tree_broads.value\n",
    "    for (cur<-iter) {\n",
    "      val s_fid = cur.getAs[Long](\"id\")\n",
    "      val s_geom = cur.getAs[Point](\"geometry\")\n",
    "      val results = tree.query(s_geom.getEnvelopeInternal).asInstanceOf[java.util.List[Row]]\n",
    "\n",
    "      for (x<-results) {\n",
    "        val m_fid = x.getAs[Long](\"id\")\n",
    "        val m_envelop = x.getAs[Geometry](\"envelop\")\n",
    "        if(m_envelop.intersects(s_geom)){\n",
    "          res_list = res_list:+(s_fid,m_fid,s_geom)\n",
    "        }\n",
    "      }\n",
    "    }\n",
    "    res_list.iterator\n",
    "  }).toDF(\"s_fid\",\"m_fid\",\"s_geom\")\n",
    "\n",
    "  dfnear = dfnear.union(dfneari)\n",
    "}\n",
    "\n",
    "dfnear.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13062"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfnear.count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:06:40\n",
      "step3: looking for effective pairs within the DBSCAN radius -> dfpair...\n",
      "\n",
      "+-----+-----+--------------------+--------------------+\n",
      "|s_fid|m_fid|              s_geom|              m_geom|\n",
      "+-----+-----+--------------------+--------------------+\n",
      "|   19|  489|POINT (-0.2040816...|POINT (-0.0796163...|\n",
      "|   19|  488|POINT (-0.2040816...|POINT (-0.0587587...|\n",
      "|   19|  465|POINT (-0.2040816...|POINT (-0.1552430...|\n",
      "|   39|  311|POINT (1.00599833...|POINT (1.16695573...|\n",
      "|   39|   64|POINT (1.00599833...|POINT (1.15057045...|\n",
      "|   39|  416|POINT (1.00599833...|POINT (0.89927494...|\n",
      "|   59|  416|POINT (0.75345168...|POINT (0.89927494...|\n",
      "|   79|  271|POINT (-0.1792310...|POINT (-0.3258320...|\n",
      "|   79|  489|POINT (-0.1792310...|POINT (-0.0796163...|\n",
      "|   79|  488|POINT (-0.1792310...|POINT (-0.0587587...|\n",
      "|   79|  465|POINT (-0.1792310...|POINT (-0.1552430...|\n",
      "|   99|  147|POINT (0.22302604...|POINT (0.33332249...|\n",
      "|   99|  283|POINT (0.22302604...|POINT (0.36557375...|\n",
      "|   99|  216|POINT (0.22302604...|POINT (0.19487138...|\n",
      "|  119|  288|POINT (1.62074019...|POINT (1.59041478...|\n",
      "|  119|  181|POINT (1.62074019...|POINT (1.50625683...|\n",
      "|  119|  450|POINT (1.62074019...|POINT (1.51993839...|\n",
      "|  139|  144|POINT (-0.9985058...|POINT (-0.9342012...|\n",
      "|  159|  372|POINT (0.14456261...|POINT (0.19495316...|\n",
      "|  159|   20|POINT (0.14456261...|POINT (0.17214778...|\n",
      "+-----+-----+--------------------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dfpair_raw = [s_fid: bigint, m_fid: bigint ... 2 more fields]\n",
       "dfpair = [s_fid: bigint, m_fid: bigint ... 2 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[s_fid: bigint, m_fid: bigint ... 2 more fields]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  3. Keep the pairs within the DBSCAN neighborhood radius: dfpair\n",
    "/*================================================================================*/\n",
    "printlog(\"step3: looking for effective pairs within the DBSCAN radius -> dfpair...\")\n",
    "\n",
    "\n",
    "\n",
    "val dfpair_raw = dfinput.join(dfnear, dfinput(\"id\")===dfnear(\"m_fid\"), \"right\")\n",
    "    .selectExpr(\"s_fid\",\"m_fid\",\"s_geom\",\"geometry as m_geom\")\n",
    "  \n",
    "spark.udf.register(\"distance\", (p: Point, q:Point) => p.distance(q))\n",
    "val dfpair = dfpair_raw.where(\"distance(s_geom,m_geom) < 0.2\") // the neighborhood radius R is set to 0.2\n",
    "    .persist(StorageLevel.MEMORY_AND_DISK)\n",
    "\n",
    "dfpair.show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:06:45\n",
      "step4: looking for temporary clusters -> dfcore ...\n",
      "\n",
      "+-----+--------------------+-------------+--------------------+\n",
      "|s_fid|              s_geom|neighbour_cnt|       neighbour_ids|\n",
      "+-----+--------------------+-------------+--------------------+\n",
      "|   26|POINT (0.95199382...|           25|[220, 460, 26, 22...|\n",
      "|   65|POINT (0.46872165...|           30|[491, 65, 258, 44...|\n",
      "|  418|POINT (0.04187413...|           22|[392, 475, 291, 4...|\n",
      "+-----+--------------------+-------------+--------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dfcore = [s_fid: bigint, s_geom: point ... 2 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[s_fid: bigint, s_geom: point ... 2 more fields]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  4. Build the temporary clusters: dfcore\n",
    "/*================================================================================*/\n",
    "printlog(\"step4: looking for temporary clusters -> dfcore ...\")\n",
    "\n",
    "\n",
    "val dfcore = dfpair.groupBy(\"s_fid\").agg(\n",
    "  first(\"s_geom\") as \"s_geom\",\n",
    "  count(\"m_fid\") as \"neighbour_cnt\",\n",
    "  collect_list(\"m_fid\") as \"neighbour_ids\"\n",
    ").where(\"neighbour_cnt>=20\")  // the minimum number of points, minpoints, is set to 20\n",
    ".persist(StorageLevel.MEMORY_AND_DISK)\n",
    "\n",
    "dfcore.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:09:26\n",
      "step5: get information for temporary clusters -> rdd_core ...\n",
      "\n",
      "before dbscan: rdd_core.count = 358\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dfpair_join = [s_fid: bigint, m_fid: bigint ... 2 more fields]\n",
       "df_fids = [m_fid: bigint]\n",
       "dfpair_core = [m_fid: bigint, s_fid: bigint ... 2 more fields]\n",
       "rdd_core = MapPartitionsRDD[192] at map at <console>:69\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "MapPartitionsRDD[192] at map at <console>:69"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  5. Get the core point information of the temporary clusters: rdd_core\n",
    "/*================================================================================*/\n",
    "printlog(\"step5: get information for temporary clusters -> rdd_core ...\")\n",
    "\n",
    "val dfpair_join = dfcore.selectExpr(\"s_fid\").join(dfpair,Seq(\"s_fid\"),\"inner\")\n",
    "val df_fids = dfcore.selectExpr(\"s_fid as m_fid\")\n",
    "val dfpair_core = df_fids.join(dfpair_join,Seq(\"m_fid\"),\"inner\")\n",
    "var rdd_core = dfpair_core.groupBy(\"s_fid\").agg(\n",
    "  min(\"m_fid\") as \"min_core_id\",\n",
    "  collect_set(\"m_fid\") as \"core_id_set\"\n",
    ").rdd.map(row =>{\n",
    "  val min_core_id = row.getAs[Long](\"min_core_id\")\n",
    "  val core_id_set = row.getAs[WrappedArray[Long]](\"core_id_set\").toArray.toSet\n",
    "  (min_core_id,core_id_set)\n",
    "})\n",
    "\n",
    "println(s\"before dbscan: rdd_core.count = ${rdd_core.count}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:09:54\n",
      "step6: run dbscan clustering ...\n",
      "\n",
      "after dbscan: rdd_core.count = 2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "mergeSets = > scala.collection.mutable.ListBuffer[Set[Long]] = <function1>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "mergeRDD: (rdd: org.apache.spark.rdd.RDD[(Long, Set[Long])], partition_cnt: Int)org.apache.spark.rdd.RDD[(Long, Set[Long])]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "> scala.collection.mutable.ListBuffer[Set[Long]] = <function1>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  6. Partition rdd_core and merge step by step: rdd_core(min_core_id, core_id_set)\n",
    "/*================================================================================*/\n",
    "printlog(\"step6: run dbscan clustering ...\")\n",
    "\n",
    "// Imports needed by this cell, so that the complete example does not rely on the core code section\n",
    "import scala.collection.mutable.ListBuffer\n",
    "import org.apache.spark.HashPartitioner\n",
    "\n",
    "// Merge function: merges temporary clusters that share a core point\n",
    "val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{\n",
    "  var result = ListBuffer[Set[Long]]()\n",
    "  while (set_list.size>0){\n",
    "    var cur_set = set_list.remove(0)\n",
    "    var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n",
    "    while(intersect_idxs.size>0){\n",
    "      for(idx<-intersect_idxs){\n",
    "        cur_set = cur_set|set_list(idx)\n",
    "      }\n",
    "      for(idx<-intersect_idxs){\n",
    "        set_list.remove(idx)\n",
    "      }\n",
    "      intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n",
    "    }\n",
    "    result = result:+cur_set\n",
    "  }\n",
    "  result\n",
    "}\n",
    "\n",
    "// Partition rdd_core and merge inside each partition; keep reducing the partition count until everything is merged in a single partition\n",
    "// If the data is too large to be merged in one partition, the final merge can stop at several partitions, which gives an approximate result.\n",
    "// rdd: (min_core_id, core_id_set)\n",
    "\n",
    "def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):\n",
    "org.apache.spark.rdd.RDD[(Long,Set[Long])] = {\n",
    "  val rdd_merged =  rdd.partitionBy(new HashPartitioner(partition_cnt))\n",
    "    .mapPartitions(iter => {\n",
    "      val buffer = ListBuffer[Set[Long]]()\n",
    "      for(t<-iter){\n",
    "        val core_id_set:Set[Long] = t._2\n",
    "        buffer.append(core_id_set)\n",
    "      }\n",
    "      val merged_buffer = mergeSets(buffer)\n",
    "      var result = List[(Long,Set[Long])]()\n",
    "      for(core_id_set<-merged_buffer){\n",
    "        val min_core_id = core_id_set.min\n",
    "        result = result:+(min_core_id,core_id_set)\n",
    "      }\n",
    "      result.iterator\n",
    "    })\n",
    "  rdd_merged\n",
    "}\n",
    "\n",
    "\n",
    "// NOTE: adjust the partition counts and the number of rounds here\n",
    "\n",
    "for(pcnt<-Array(16,8,4,1)){\n",
    "  rdd_core = mergeRDD(rdd_core,pcnt)\n",
    "}\n",
    "\n",
    "println(s\"after dbscan: rdd_core.count = ${rdd_core.count}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:10:04\n",
      "step7: get cluster ids ...\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dfcluster_ids = [cluster_id: bigint, s_fid: bigint]\n",
       "dfclusters = [s_fid: bigint, s_geom: point ... 3 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[s_fid: bigint, s_geom: point ... 3 more fields]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  7. Get the cluster id of every core point\n",
    "/*================================================================================*/\n",
    "printlog(\"step7: get cluster ids ...\")\n",
    "\n",
    "val dfcluster_ids = rdd_core.flatMap(t => {\n",
    "  val cluster_id = t._1\n",
    "  val id_set = t._2\n",
    "  for(core_id<-id_set) yield (cluster_id, core_id)\n",
    "}).toDF(\"cluster_id\",\"s_fid\")\n",
    "\n",
    "val dfclusters =  dfcore.join(dfcluster_ids, Seq(\"s_fid\"), \"left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================2019-11-24 18:10:13\n",
      "step8: evaluate cluster representation ...\n",
      "\n",
      "+----------+--------------------+--------------------+------------------+--------------------+\n",
      "|cluster_id|representation_point|neighbour_points_cnt|cluster_points_cnt|              id_set|\n",
      "+----------+--------------------+--------------------+------------------+--------------------+\n",
      "|         0|POINT (1.95163238...|                  32|               242|[365, 138, 101, 4...|\n",
      "|         2|POINT (0.95067226...|                  34|               241|[69, 347, 468, 35...|\n",
      "+----------+--------------------+--------------------+------------------+--------------------+\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "rdd_cluster = MapPartitionsRDD[215] at map at <console>:59\n",
       "rdd_result = ShuffledRDD[216] at reduceByKey at <console>:67\n",
       "dfresult = [cluster_id: bigint, representation_point: point ... 3 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[cluster_id: bigint, representation_point: point ... 3 more fields]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "/*================================================================================*/\n",
    "//  8. Find the representative core point and the number of points of each cluster\n",
    "/*================================================================================*/\n",
    "printlog(\"step8: evaluate cluster representation ...\")\n",
    "\n",
    "val rdd_cluster = dfclusters.rdd.map(row=> {\n",
    "  val cluster_id = row.getAs[Long](\"cluster_id\")\n",
    "  val s_geom = row.getAs[Point](\"s_geom\")\n",
    "  val neighbour_cnt = row.getAs[Long](\"neighbour_cnt\")\n",
    "  val id_set = row.getAs[WrappedArray[Long]](\"neighbour_ids\").toSet\n",
    "  (cluster_id,(s_geom,neighbour_cnt,id_set))\n",
    "})\n",
    "\n",
    "val rdd_result = rdd_cluster.reduceByKey((a,b)=>{\n",
    "  val id_set = a._3 | b._3\n",
    "  val result = if(a._2>=b._2) (a._1,a._2,id_set)\n",
    "  else (b._1,b._2,id_set)\n",
    "  result\n",
    "})\n",
    "\n",
    "val dfresult = rdd_result.map(t=>{\n",
    "  val cluster_id = t._1\n",
    "  val representation_point = t._2._1\n",
    "  val neighbour_points_cnt = t._2._2\n",
    "  val id_set = t._2._3\n",
    "  val cluster_points_cnt = id_set.size\n",
    "  (cluster_id,representation_point,neighbour_points_cnt,cluster_points_cnt,id_set)\n",
    "}).toDF(\"cluster_id\",\"representation_point\",\"neighbour_points_cnt\",\"cluster_points_cnt\",\"id_set\")\n",
    "\n",
    "dfresult.show(3)"
   ]
  },
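  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that in our result\n",
    "\n",
    "the number of clusters is 2, and\n",
    "\n",
    "the number of noise points is 500-242-241 = 17.\n",
    "\n",
    "This is exactly the same as the result obtained by calling sklearn."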
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](data/sklearn的DBSCAN聚类结果.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Spark - Scala",
   "language": "scala",
   "name": "spark_scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "2.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
